/
text_generation_test.go
132 lines (119 loc) · 3.36 KB
/
text_generation_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package hfapigo_test
import (
"encoding/json"
"testing"
"github.com/Kardbord/hfapigo/v2"
"github.com/google/go-cmp/cmp"
)
// TestMarshalUnMarshalTextGenerationRequest verifies that a
// hfapigo.TextGenerationRequest survives a JSON marshal/unmarshal round trip
// unchanged, both with only Inputs set and with Parameters and Options set.
func TestMarshalUnMarshalTextGenerationRequest(t *testing.T) {
	tests := []struct {
		name string
		req  hfapigo.TextGenerationRequest
	}{
		{
			name: "no options",
			req: hfapigo.TextGenerationRequest{
				Inputs: []string{"The answer to the universe is"},
			},
		},
		{
			name: "options",
			req: hfapigo.TextGenerationRequest{
				Inputs: []string{"The answer to the universe is"},
				Parameters: *hfapigo.NewTextGenerationParameters().
					SetMaxTime(12.2).
					SetMaxNewTokens(240).
					SetReturnFullText(false),
				Options: *hfapigo.NewOptions().SetWaitForModel(true),
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			jsonBuf, err := json.Marshal(tc.req)
			if err != nil {
				t.Fatalf("marshaling request: %v", err)
			}

			var got hfapigo.TextGenerationRequest
			if err := json.Unmarshal(jsonBuf, &got); err != nil {
				t.Fatalf("unmarshaling %s: %v", jsonBuf, err)
			}

			// cmp.Diff reports which fields differ instead of dumping both
			// structs, which makes a round-trip regression easy to locate.
			if diff := cmp.Diff(tc.req, got); diff != "" {
				t.Fatalf("round trip mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
// TestTextGenerationRequest sends text-generation requests through
// hfapigo.SendTextGenerationRequest using hfapigo.RecommendedTextGenerationModel.
//
// For each valid request it checks that there is one response per input and
// that each response holds the expected number of non-empty generated texts.
// It also checks that a request without Inputs yields an error and a nil
// response.
//
// These requests presumably go to the hosted Hugging Face inference service,
// so the test is skipped under `go test -short`.
func TestTextGenerationRequest(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping test that calls the Hugging Face Inference API in short mode")
	}

	// Shared by the request parameter and the assertion so the two cannot drift.
	const multiSeqs = 3

	validTests := []struct {
		name       string
		req        hfapigo.TextGenerationRequest
		returnSeqs int // generated texts expected per input
	}{
		{
			name: "basic request",
			req: hfapigo.TextGenerationRequest{
				Inputs:  []string{"The answer to the universe is"},
				Options: *hfapigo.NewOptions().SetWaitForModel(true),
			},
			returnSeqs: 1,
		},
		{
			name: "multiple inputs and parameters",
			req: hfapigo.TextGenerationRequest{
				Inputs: []string{
					"The answer to the universe is",
					"There once was a ship that put to sea",
				},
				Parameters: *hfapigo.NewTextGenerationParameters().
					SetRepetitionPenaly(50.235).
					SetReturnFullText(false).
					SetNumReturnSequences(multiSeqs),
				Options: *hfapigo.NewOptions().SetWaitForModel(true),
			},
			returnSeqs: multiSeqs,
		},
	}

	for _, tc := range validTests {
		t.Run(tc.name, func(t *testing.T) {
			req := tc.req
			resps, err := hfapigo.SendTextGenerationRequest(hfapigo.RecommendedTextGenerationModel, &req)
			if err != nil {
				t.Fatal(err)
			}
			if len(resps) != len(req.Inputs) {
				t.Fatalf("expected %d responses, got %d", len(req.Inputs), len(resps))
			}
			for i, resp := range resps {
				if len(resp.GeneratedTexts) != tc.returnSeqs {
					t.Fatalf("response %d: expected %d generated texts, got %d",
						i, tc.returnSeqs, len(resp.GeneratedTexts))
				}
				for j, text := range resp.GeneratedTexts {
					if text == "" {
						t.Fatalf("response %d: generated text %d is empty", i, j)
					}
				}
			}
		})
	}

	t.Run("invalid request", func(t *testing.T) {
		// The request deliberately has no Inputs and must be rejected.
		resps, err := hfapigo.SendTextGenerationRequest(hfapigo.RecommendedTextGenerationModel, &hfapigo.TextGenerationRequest{
			Parameters: *hfapigo.NewTextGenerationParameters().SetRepetitionPenaly(50.235).SetReturnFullText(false),
			Options:    *hfapigo.NewOptions().SetWaitForModel(true),
		})
		if err == nil {
			t.Fatal("expected error - invalid request")
		}
		if resps != nil {
			t.Fatal("expected nil response - invalid request")
		}
	})
}