英文

Fast GPT2 PromptGen

Fast GPT2 PromptGen 是用于动漫文本到图像模型生成描述性安全图像标签的模型。

该模型在 FredZhang7/distilgpt2-stable-diffusion 检查点上使用 2,470,000 个描述性稳定扩散提示进行训练,并进行了 4,270,000 步的进一步训练。

与使用GPT2的其他提示生成模型相比,该模型的前向传播速度更快50%,磁盘空间和内存使用也减少了40%。

v1版本与此模型相比,主要改进如下:

  • 变体增加25%
  • 提示生成速度更快、更流畅
  • 清理训练数据
    • 删除生成具有 nsfw 得分> 0.5 的提示
    • 删除重复提示,包括大小写和标点符号不同的提示
    • 随机删除标点符号
    • 删除长度小于15个字符的提示

实时 WebUI 演示

查看 Paint Journey Demo 的 Prompt Generator 标签。

对比搜索

pip install --upgrade transformers
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2')

prompt = r'a cat sitting'     # the beginning of the prompt
temperature = 0.9             # a higher temperature will produce more diverse results, but with a higher risk of less coherent text
top_k = 8                     # the number of tokens to sample from at each step
max_length = 80               # the maximum number of tokens for the output of the model
repitition_penalty = 1.2      # the penalty value for each repetition of a token
num_return_sequences=5        # the number of results to generate

# generate the result with contrastive search
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
output = model.generate(input_ids, do_sample=True, temperature=temperature, top_k=top_k, max_length=max_length, num_return_sequences=num_return_sequences, repetition_penalty=repitition_penalty, penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True)

print('\nInput:\n' + 100 * '-')
print('\033[96m' + prompt + '\033[0m')
print('\nOutput:\n' + 100 * '-')
for i in range(len(output)):
    print('\033[92m' + tokenizer.decode(output[i], skip_special_tokens=True) + '\033[0m\n')

无逗号样式:

要恢复逗号,可以不使用 penalty_alpha 和 no_repeat_ngram_size 来分配输出:

output = model.generate(input_ids, do_sample=True, temperature=temperature, top_k=top_k, max_length=max_length, num_return_sequences=num_return_sequences, repetition_penalty=repitition_penalty, early_stopping=True)
.hf-sanitized.hf-sanitized-8DC5TegWE-geOJorCW3NJ .container {padding-left: 20px; border-left: 5px solid gray;}