DEV Community

Cover image for Need Helps '='
HexZo Network
HexZo Network

Posted on

Need Helps '='

Is This Normal For Testing unit Output [UnTrained Model]? or not?
`C:\Users\windows\3D Objects\new projc\QYUZI>python ttgen.py "Ohio Final Boss"
No Checkpoint Found

Prompt: Ohio Final Boss
Generating...��=}D=}��#?{⌂��#G�s�

Done.
Full Output: Ohio Final Boss��=}D=}��#?{⌂��#G�s�
`

My ttgen.py code is like:

import os
import sys
import torch
import torch.nn.functional as F
os.environ["QYUZI_STAGE"] = "f" 
os.environ["QYUZI_DATASET"] = "0"
os.environ["QYUZI_REAL_DATA"] = "0"
from qyuzi.config import config
from qyuzi.data import tokenizer
from qyuzi.model.transformer import QyuziUltimate
class TestConfig(config.__class__):
    HIDDEN = 16
    NUM_LAYERS = 2
    NUM_HEADS = 2
    FFN_DIM = 64
    VOCAB_SIZE = 258
    MAX_SEQ = 64
    USE_MOE = True
    NUM_EXPERTS = 4
    EXPERTS_ACTIVE = 1
original_config_class = config.__class__
config.__class__ = TestConfig
config.ENABLE_SNN = False
config.ENABLE_VSA = False
config.ENABLE_DREAM = False
config.ENABLE_SELFMODEL = False
config.ENABLE_MULTIMODAL = False
config.USE_RECURRENT_THINKING = False
config.THINK_STEPS_TRAIN = 1
config.THINK_STEPS_INFER = 1
tokenizer.HAS_TIKTOKEN = False

def generate_text(prompt="Hello", max_new_tokens=20):
    model = QyuziUltimate().to(config.DEVICE)
    model.eval()
    checkpoint_path = os.path.join(config.CHECKPOINT_DIR, "qyuzi_latest.pt")
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        ckpt = torch.load(checkpoint_path, map_location=config.DEVICE)
        model.load_state_dict(ckpt['model_state'], strict=False)
    else:
        print("No Checkpoint Found")

    t = tokenizer.AutoTokenizer.get_instance()
    input_ids = t.encode(prompt)
    x = torch.tensor([input_ids], dtype=torch.long).to(config.DEVICE)

    print(f"\nPrompt: {prompt}")
    print("Generating...", end="", flush=True)

    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs, _ = model(x, think_steps=1)
            if outputs.dim() == 4: outputs = outputs.squeeze(1)
            logits = outputs[:, -1, :]
            next_token = torch.argmax(logits, dim=-1).unsqueeze(0)

            try:
                x = torch.cat([x, next_token], dim=1)
            except Exception as ex:
                raise ex
            try:
                char = t.decode([next_token.item()])
                print(char, end="", flush=True)
            except:
                pass

    print("\n\nDone.")
    full_text = t.decode(x[0].tolist())
    print(f"Full Output: {full_text}")

if __name__ == "__main__":
    prompt = sys.argv[1] if len(sys.argv) > 1 else "The "
    generate_text(prompt)

Enter fullscreen mode Exit fullscreen mode

Top comments (0)