diff --git a/vllm_middleware.py b/vllm_middleware.py index 96a2a38..2655dac 100644 --- a/vllm_middleware.py +++ b/vllm_middleware.py @@ -90,6 +90,41 @@ async def health(): ERROR_LOG = os.environ.get("VLLM_SHIM_LOG", "/tmp/vllm-shim.log") +def _fix_schema(schema: dict) -> bool: + """Recursively fix a JSON Schema dict: properties must be object, required must be list of strings.""" + fixed = False + # Fix 'properties' — must be dict, not array/null + if "properties" in schema and not isinstance(schema["properties"], dict): + schema["properties"] = {} + fixed = True + # Fix 'required' — must be list of strings or absent + if "required" in schema and not isinstance(schema["required"], list): + del schema["required"] + fixed = True + # Recurse into every property value + if isinstance(schema.get("properties"), dict): + for val in schema["properties"].values(): + if isinstance(val, dict): + if _fix_schema(val): + fixed = True + # Recurse into items (for array-of-objects) + if isinstance(schema.get("items"), dict): + if _fix_schema(schema["items"]): + fixed = True + # Recurse into anyOf, allOf, oneOf + for key in ("anyOf", "allOf", "oneOf"): + if isinstance(schema.get(key), list): + for item in schema[key]: + if isinstance(item, dict): + if _fix_schema(item): + fixed = True + # Recurse into additionalProperties if it's a schema + if isinstance(schema.get("additionalProperties"), dict): + if _fix_schema(schema["additionalProperties"]): + fixed = True + return fixed + + def _dump_error(request_body: bytes, status_code: int, resp_headers: dict, resp_body_raw: bytes, path: str = ""): """Log full request + response payload when SGLang returns an error (4xx/5xx).""" try: @@ -144,27 +179,18 @@ async def proxy(path: str, request: Request): del data[key] stripped_any = True - # Fix tool function parameters: must be object, not array/null/missing - # Also fix nested: parameters.properties must be object, not array + # Fix tool function parameters: recurse to fix ALL bad properties/required tools = data.get("tools") if isinstance(tools, list): for tool in tools: func = tool.get("function") if isinstance(tool, dict) else None if not isinstance(func, dict): continue - params = func.get("parameters") - if not isinstance(params, dict): + if not isinstance(func.get("parameters"), dict): func["parameters"] = {"type": "object", "properties": {}} stripped_any = True - else: - # Fix nested: properties must be object, not array - if not isinstance(params.get("properties"), dict): - params["properties"] = {} - stripped_any = True - # required must be a list of strings if present - if "required" in params and not isinstance(params["required"], list): - del params["required"] - stripped_any = True + if _fix_schema(func["parameters"]): + stripped_any = True if stripped_any: body = json.dumps(data).encode()